# coding:utf-8

import sys
from src.constant.file_and_path_constant import FileAndPathConstant

# Append the jormungandr directory to the module search path.
# A raw string keeps the Windows separators literal: '\g', '\h', '\d' and '\j'
# are invalid escape sequences in a normal string literal.
sys.path.append(FileAndPathConstant.System_Drive + r'\github-repository\hades\dev-project\jormungandr')

from src.manager.oracle_manager import OracleManager
from src.manager.lstm_manager import LstmManager
from src.constant.oracle import Oracle
from src.manager.log_manager import LogManager
from src.manager.nn_manager import NnManager
from pandas.core.frame import DataFrame
import numpy as np

# Module-level logger, obtained from the project's LogManager and keyed by this module's name.
Logger = LogManager.get_logger(__name__)


class NnHandler:
    """
    Neural-network handler.

    Connects to Oracle on construction, reads closing-price samples from
    ``stock_transaction_data_all`` and hands them to ``NnManager``.
    """

    # The SQL texts are constant and use named bind variables, so the same
    # statement is reused for every date instead of a new literal per iteration.
    # NOTE(review): assumes the cursor returned by OracleManager.get_cursor()
    # accepts a dict of named binds (cx_Oracle / python-oracledb style) - confirm.
    _DATES_SQL = ("select t.date_ from stock_transaction_data_all t "
                  "where t.code_=:stock_code "
                  "and t.date_ between to_date(:start_day,'yyyy-mm-dd') and to_date(:end_day,'yyyy-mm-dd') "
                  "order by t.date_ desc")
    _WINDOW_SQL = ("select close_price from ("
                   "select * from stock_transaction_data_all t "
                   "where t.code_=:stock_code and t.date_<to_date(:trade_day,'yyyy-mm-dd') "
                   "order by t.date_ desc) "
                   "where rownum<=:window_size")
    _LABEL_SQL = ("select t.close_price from stock_transaction_data_all t "
                  "where t.code_=:stock_code and t.date_=to_date(:trade_day,'yyyy-mm-dd')")

    def __init__(self):
        """
        Constructor: create the oracle_manager, connect, and obtain a cursor.
        """
        self.oracle_manager = OracleManager(Oracle.Username, Oracle.Password, Oracle.Url)
        self.oracle_manager.connect()
        self.cursor = self.oracle_manager.get_cursor()

    def close_cursor_and_connect(self):
        """
        Close the cursor and the database connection.

        The connection is closed even if closing the cursor raises.
        :return: None
        """
        try:
            self.cursor.close()
        finally:
            self.oracle_manager.connect_close()

    def test(self):
        """
        Run the two NnManager demo routines, each on a fresh NnManager instance.
        :return: None
        """
        # Forward-propagation network on hand-written arrays and matrices
        nn_manager = NnManager()
        nn_manager.fpnn_simple()

        # Variable and tensor experiments
        nn_manager = NnManager()
        nn_manager.test_variable_and_tensor()

    @staticmethod
    def _to_day_string(value):
        """
        Return ``value`` as 'YYYY-MM-DD' text.

        Objects with ``strftime`` (date / datetime) are formatted; anything
        else (normally an already formatted string) is returned unchanged.
        """
        return value.strftime("%Y-%m-%d") if hasattr(value, "strftime") else value

    def load_train_data(self, code='000004', start_date='2010-01-01', end_date='2020-12-31', window=21):
        """
        Load training data for one stock.

        For every row date of ``code`` between ``start_date`` and ``end_date``
        (newest first) one sample is produced:
          * features: up to ``window`` closing prices of the rows strictly
            BEFORE that date, newest first (fewer when less history exists);
          * label: the closing price on that date.

        The defaults reproduce the previously hard-coded query
        (stock 000004, 2010-01-01 .. 2020-12-31, 21 features).

        :param code: stock code matched against ``code_``
        :param start_date: lower bound, 'YYYY-MM-DD' string or date/datetime
        :param end_date: upper bound, 'YYYY-MM-DD' string or date/datetime
        :param window: maximum number of preceding closing prices per sample
        :return: tuple ``(train_x, train_label)`` - list of feature lists and
                 list of labels, index-aligned
        :raises ValueError: if ``window`` is smaller than 1
        """
        if window < 1:
            raise ValueError("window must be >= 1, got %r" % (window,))

        # 2-D feature array
        train_x = []
        # 1-D label array
        train_label = []

        # All dates in range
        self.cursor.execute(self._DATES_SQL, {
            "stock_code": code,
            "start_day": self._to_day_string(start_date),
            "end_day": self._to_day_string(end_date),
        })
        date_rows = self.cursor.fetchall()

        for date_row in date_rows:
            # Bind the day as text and convert with to_date() in SQL, exactly as
            # the original literal did (this drops any time-of-day component).
            trade_day = self._to_day_string(date_row[0])

            # Closing prices of the rows before trade_day, newest first
            self.cursor.execute(self._WINDOW_SQL, {
                "stock_code": code,
                "trade_day": trade_day,
                "window_size": window,
            })
            train_x.append([price_row[0] for price_row in self.cursor.fetchall()])

            # Closing price on trade_day itself is the label
            self.cursor.execute(self._LABEL_SQL, {
                "stock_code": code,
                "trade_day": trade_day,
            })
            label_rows = self.cursor.fetchall()
            train_label.append(label_rows[0][0])
        return train_x, train_label

    def do(self):
        """
        Load the training data and feed it to the TensorFlow forward-propagation network.
        :return: None
        """
        # Training data: (feature list, label) per sample
        train_x, train_label = self.load_train_data()

        # Forward-propagation network built with TensorFlow variables and constants
        nn_manager = NnManager()
        nn_manager.fpnn_tensorflow(train_x, train_label)
